{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fe5f5ee6",
   "metadata": {
    "id": "TKvYiJgnYExi"
   },
   "source": [
    "This notebook provides examples to go along with the [textbook](http://manipulation.csail.mit.edu/clutter.html).  I recommend having both windows open, side-by-side!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6122be6d",
   "metadata": {
    "id": "A4QOaw_zYLfI",
    "lines_to_end_of_cell_marker": 2
   },
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "import numpy as np\n",
    "from pydrake.all import (\n",
    "    PointCloud,\n",
    "    Rgba,\n",
    "    RigidTransform,\n",
    "    RotationMatrix,\n",
    "    Sphere,\n",
    "    StartMeshcat,\n",
    ")\n",
    "from scipy.spatial import KDTree\n",
    "\n",
    "from manipulation import running_as_notebook\n",
    "from manipulation.meshcat_utils import AddMeshcatTriad\n",
    "from manipulation.mustard_depth_camera_example import MustardExampleSystem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae062858",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start the visualizer. `meshcat` is a module-level global that\n",
    "# `normal_estimation()` below reads directly, so run this cell first.\n",
    "meshcat = StartMeshcat()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e81a649c",
   "metadata": {},
   "source": [
    "# Estimating normals (and local curvature)\n",
    "\n",
    "TODO: Add the version from depth images (nearest pixels instead of nearest neighbors), and implement it in DepthImageToPointCloud."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2357a8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _estimate_local_frame(neighbor_pts, query, p_WC):\n",
    "    \"\"\"Fits a right-handed orthonormal basis to a neighborhood of points.\n",
    "\n",
    "    Args:\n",
    "        neighbor_pts: (N, 3) array with one neighboring point per row.\n",
    "        query: (3,) position of the point whose normal is being estimated.\n",
    "        p_WC: (3,) position of the camera, in the same frame as the points.\n",
    "\n",
    "    Returns:\n",
    "        A 3x3 matrix whose columns are the eigenvectors of the neighborhood\n",
    "        scatter matrix, sorted by decreasing eigenvalue. The last column is\n",
    "        the estimated surface normal, oriented towards the camera.\n",
    "    \"\"\"\n",
    "    pstar = np.mean(neighbor_pts, axis=0)\n",
    "    prel = neighbor_pts - pstar\n",
    "    W = prel.T @ prel\n",
    "    # eigh returns eigenvalues in ascending order, so V[:, 0] corresponds to\n",
    "    # the smallest eigenvalue, and V[:, 2] to the largest.\n",
    "    _, V = np.linalg.eigh(W)\n",
    "    # R[:, 0] corresponds to the largest eigenvalue, and R[:, 2] to the\n",
    "    # smallest (the normal).\n",
    "    R = np.fliplr(V)\n",
    "\n",
    "    # Handle improper rotations: det(R) is -1 for a reflection, so scaling\n",
    "    # the last column by it restores a right-handed basis.\n",
    "    R = R @ np.diag([1, 1, np.linalg.det(R)])\n",
    "\n",
    "    # If the normal is not pointing towards the camera, i.e. it has a\n",
    "    # negative component along the vector from the query point to the\n",
    "    # camera...\n",
    "    if (p_WC - query).dot(R[:, 2]) < 0:\n",
    "        # then flip the y and z axes (this keeps the basis right-handed).\n",
    "        R = R @ np.diag([1, -1, -1])\n",
    "    return R\n",
    "\n",
    "\n",
    "def normal_estimation():\n",
    "    \"\"\"Interactively visualizes least-squares normal estimation in Meshcat.\n",
    "\n",
    "    Evaluates the \"camera0_point_cloud\" output of MustardExampleSystem, crops\n",
    "    it to a 0.6 m cube around the origin, and lets the user select a point\n",
    "    with the \"point\" slider (or the left/right arrow keys). For the selected\n",
    "    point, its nearest neighbors are drawn in blue and a triad showing the\n",
    "    estimated local frame is placed at the point.\n",
    "\n",
    "    Uses the module-level `meshcat` instance. Runs until the\n",
    "    'Stop Normal Estimation' button (or ESC) is pressed; when not running as\n",
    "    a notebook, it processes a single point and returns.\n",
    "    \"\"\"\n",
    "    system = MustardExampleSystem()\n",
    "    context = system.CreateDefaultContext()\n",
    "\n",
    "    meshcat.Delete()\n",
    "    meshcat.DeleteAddedControls()\n",
    "    meshcat.SetProperty(\"/Background\", \"visible\", False)\n",
    "\n",
    "    point_cloud = system.GetOutputPort(\"camera0_point_cloud\").Eval(context)\n",
    "    cloud = point_cloud.Crop(\n",
    "        lower_xyz=[-0.3, -0.3, -0.3], upper_xyz=[0.3, 0.3, 0.3]\n",
    "    )\n",
    "    meshcat.SetObject(\"point_cloud\", cloud)\n",
    "\n",
    "    # Extract camera position\n",
    "    plant = system.GetSubsystemByName(\"plant\")\n",
    "    p_WC = (\n",
    "        plant.GetFrameByName(\"camera0_origin\")\n",
    "        .CalcPoseInWorld(plant.GetMyContextFromRoot(context))\n",
    "        .translation()\n",
    "    )\n",
    "\n",
    "    kdtree = KDTree(cloud.xyzs().T)\n",
    "\n",
    "    num_closest = 40\n",
    "    neighbors = PointCloud(num_closest)\n",
    "    AddMeshcatTriad(meshcat, \"least_squares_basis\", length=0.03, radius=0.0005)\n",
    "    # The query marker's geometry never changes; only its pose is updated\n",
    "    # inside the loop.\n",
    "    meshcat.SetObject(\"query\", Sphere(0.001), Rgba(0, 1, 0))\n",
    "\n",
    "    meshcat.AddSlider(\n",
    "        \"point\",\n",
    "        min=0,\n",
    "        max=cloud.size() - 1,\n",
    "        step=1,\n",
    "        value=429,  # 4165,\n",
    "        decrement_keycode=\"ArrowLeft\",\n",
    "        increment_keycode=\"ArrowRight\",\n",
    "    )\n",
    "    meshcat.AddButton(\"Stop Normal Estimation\", \"Escape\")\n",
    "    print(\n",
    "        \"Press ESC or the 'Stop Normal Estimation' button in Meshcat to continue\"\n",
    "    )\n",
    "    last_index = -1\n",
    "    while meshcat.GetButtonClicks(\"Stop Normal Estimation\") < 1:\n",
    "        index = round(meshcat.GetSliderValue(\"point\"))\n",
    "        if index == last_index:\n",
    "            # Nothing changed; poll the slider again shortly.\n",
    "            time.sleep(0.1)\n",
    "            continue\n",
    "        last_index = index\n",
    "\n",
    "        query = cloud.xyz(index)\n",
    "        meshcat.SetTransform(\"query\", RigidTransform(query))\n",
    "        (distances, indices) = kdtree.query(\n",
    "            query, k=num_closest, distance_upper_bound=0.1\n",
    "        )\n",
    "        # KDTree.query pads its result with infinite distances (and the\n",
    "        # out-of-range index `kdtree.n`) when fewer than `k` points lie\n",
    "        # within `distance_upper_bound`; drop those placeholders before\n",
    "        # indexing into the cloud.\n",
    "        indices = indices[np.isfinite(distances)]\n",
    "\n",
    "        neighbors.resize(len(indices))\n",
    "        neighbors.mutable_xyzs()[:] = cloud.xyzs()[:, indices]\n",
    "\n",
    "        meshcat.SetObject(\n",
    "            \"neighbors\", neighbors, rgba=Rgba(0, 0, 1), point_size=0.001\n",
    "        )\n",
    "\n",
    "        R = _estimate_local_frame(neighbors.xyzs().T, query, p_WC)\n",
    "\n",
    "        meshcat.SetTransform(\n",
    "            \"least_squares_basis\", RigidTransform(RotationMatrix(R), query)\n",
    "        )\n",
    "\n",
    "        if not running_as_notebook:\n",
    "            # Non-interactive runs (e.g. tests) only process a single point.\n",
    "            break\n",
    "\n",
    "    meshcat.DeleteAddedControls()\n",
    "\n",
    "\n",
    "normal_estimation()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.10 64-bit",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
